import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet50
from tensorboardX import SummaryWriter
# Restrict this process to the first GPU. The variable is set at import time,
# i.e. before main() issues any CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def _train_one_epoch(net, loader, loss_function, optimizer, device, epoch, epochs):
    """Run one optimisation pass over ``loader`` and report its statistics.

    Args:
        net: model to optimise; it is switched to training mode.
        loader: DataLoader yielding ``(images, labels)`` batches.
        loss_function: callable mapping ``(logits, labels)`` to a scalar loss.
        optimizer: optimizer stepping the parameters of ``net``.
        device: device the batches are moved to before the forward pass.
        epoch: zero-based index of the current epoch (progress bar only).
        epochs: total number of epochs (progress bar only).

    Returns:
        tuple: ``(mean per-batch loss, correct predictions / dataset size)``.
    """
    net.train()
    running_loss = 0.0
    correct = 0
    train_bar = tqdm(loader, file=sys.stdout)
    for images, labels in train_bar:
        # Move each batch to the device once and reuse the tensors below.
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = net(images)
        loss = loss_function(logits, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        # Reuse the logits of the forward pass above: a second forward pass
        # here would double the compute, build an unused autograd graph and
        # update the BatchNorm running statistics a second time per batch.
        correct += torch.eq(torch.max(logits, dim=1)[1], labels).sum().item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                 epochs,
                                                                 batch_loss)
    return running_loss / len(loader), correct / len(loader.dataset)


def _evaluate(net, loader, loss_function, device, epoch, epochs):
    """Evaluate ``net`` on ``loader`` with gradients disabled.

    Args:
        net: model to evaluate; it is switched to evaluation mode.
        loader: DataLoader yielding ``(images, labels)`` batches.
        loss_function: callable mapping ``(logits, labels)`` to a scalar loss.
        device: device the batches are moved to before the forward pass.
        epoch: zero-based index of the current epoch (progress bar only).
        epochs: total number of epochs (progress bar only).

    Returns:
        tuple: ``(mean per-batch loss, correct predictions / dataset size)``.
    """
    net.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        val_bar = tqdm(loader, file=sys.stdout)
        for val_images, val_labels in val_bar:
            val_images = val_images.to(device)
            val_labels = val_labels.to(device)
            outputs = net(val_images)
            # The loss must be computed on the validation batch itself; the
            # training loop's last loss tensor says nothing about this data.
            total_loss += loss_function(outputs, val_labels).item()
            predict_y = torch.max(outputs, dim=1)[1]
            correct += torch.eq(predict_y, val_labels).sum().item()
            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                       epochs)
    return total_loss / len(loader), correct / len(loader.dataset)


def main():
    """Fine-tune a pretrained ResNet-50 on the image folders under ``data/``.

    Reads ``data/train`` and ``data/val`` with ``ImageFolder``, writes the
    index-to-class mapping to ``class_indices.json``, loads the pretrained
    weights from ``./resnet50-pre.pth``, replaces the final fully connected
    layer with one sized to the number of classes found in the training set,
    and trains with Adam. Per-epoch loss and accuracy are logged to
    TensorBoard under ``logs/``; the weights with the best validation accuracy
    are saved to ``./resNet50.pth``.

    Raises:
        FileNotFoundError: if the pretrained weight file does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Data preprocessing: random crop/flip augmentation for training, a
    # deterministic resize + centre crop for validation.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    image_path = os.path.join("data")  # data set root
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    class_to_idx = train_dataset.class_to_idx
    num_classes = len(class_to_idx)
    # Invert the mapping so the class name can be looked up by index.
    cla_dict = {idx: name for name, idx in class_to_idx.items()}
    # ``indent`` only controls the pretty-printing whitespace of the JSON
    # file; it is unrelated to the number of classes.
    json_str = json.dumps(cla_dict, indent=9)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = 4  # number of dataloader worker processes
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # Build the model and load the pretrained weights.
    net = resnet50()
    # NOTE(review): the checkpoint must match the project's resnet50
    # definition -- confirm where this file is obtained from.
    model_weight_path = "./resnet50-pre.pth"
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if not os.path.exists(model_weight_path):
        raise FileNotFoundError("file {} does not exist.".format(model_weight_path))
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # To train only the new head, set requires_grad = False on the existing
    # parameters here; the optimizer below already skips frozen parameters.
    # Replace the classification head, sized from the training set so that it
    # always agrees with the labels ImageFolder produces.
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, num_classes)
    net.to(device)

    # Loss function.
    loss_function = nn.CrossEntropyLoss()
    # Optimizer over the trainable parameters only.
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30  # number of training epochs
    best_acc = 0.0
    save_path = './resNet50.pth'  # destination of the best model weights

    # The writer is created right before the loop and closed in ``finally``
    # so buffered scalars are flushed even if training is interrupted.
    writer = SummaryWriter(log_dir='logs', flush_secs=60)
    try:
        for epoch in range(epochs):
            train_loss, train_acc = _train_one_epoch(
                net, train_loader, loss_function, optimizer, device, epoch, epochs)
            writer.add_scalar('Train_loss', train_loss, epoch)
            writer.add_scalar('Train_acc', train_acc, epoch)

            val_loss, val_accurate = _evaluate(
                net, validate_loader, loss_function, device, epoch, epochs)
            print('[epoch %d] train_loss: %.3f  train_acc: %.3f  val_loss: %.3f  val_accuracy: %.3f' % (
                epoch + 1, train_loss, train_acc, val_loss, val_accurate))

            # Keep only the weights with the best validation accuracy so far.
            if val_accurate > best_acc:
                best_acc = val_accurate
                torch.save(net.state_dict(), save_path)

            writer.add_scalar('Val_acc', val_accurate, epoch)
            writer.add_scalar('Val_loss', val_loss, epoch)
    finally:
        writer.close()
    print('Finished Training')


# Start training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()
